
"""Predicting 3d poses from 2d joints"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import time
import math
import h5py
import copy
import random

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from six.moves import xrange  # pylint: disable=redefined-builtin

from . import viz
from . import cameras
from . import data_utils
from . import procrustes
from . import linear_model

from . import com_path

# === Command-line flags ===

# Optimization
tf.app.flags.DEFINE_float("learning_rate", 1e-3, "Learning rate")
tf.app.flags.DEFINE_float(
    "dropout", 1, "Dropout keep probability. 1 means no dropout")
tf.app.flags.DEFINE_integer(
    "batch_size", 64, "Batch size to use during training")
tf.app.flags.DEFINE_integer(
    "epochs", 200, "How many epochs we should train for")
tf.app.flags.DEFINE_boolean("camera_frame", False,
                            "Convert 3d poses to camera coordinates")
tf.app.flags.DEFINE_boolean(
    "max_norm", False, "Apply maxnorm constraint to the weights")
tf.app.flags.DEFINE_boolean("batch_norm", False, "Use batch_normalization")

# Data loading
tf.app.flags.DEFINE_boolean("predict_14", False, "predict 14 joints")
tf.app.flags.DEFINE_boolean(
    "use_sh", False, "Use 2d pose predictions from StackedHourglass")
tf.app.flags.DEFINE_string(
    "action", "All", "The action to train on. 'All' means all the actions")

# Architecture
tf.app.flags.DEFINE_integer("linear_size", 1024, "Size of each model layer.")
tf.app.flags.DEFINE_integer("num_layers", 2, "Number of layers in the model.")
tf.app.flags.DEFINE_boolean(
    "residual", False, "Whether to add a residual connection every 2 layers")

# Evaluation
tf.app.flags.DEFINE_boolean(
    "procrustes", False, "Apply procrustes analysis at test time")
tf.app.flags.DEFINE_boolean(
    "evaluateActionWise", False,
    "Whether to evaluate the test error separately for each action")

# Directories
tf.app.flags.DEFINE_string("model_dir", os.path.join(os.getcwd(), "data/"), "Model directory")
tf.app.flags.DEFINE_string("data_dir", os.path.join(os.getcwd(), "data/h36m/"), "Data directory")
tf.app.flags.DEFINE_string("cameras_path", os.path.join(os.getcwd(), "data/h36m/cameras.h5"), "Directory to load camera parameters")
tf.app.flags.DEFINE_string("summaries_dir", os.path.join(os.getcwd(), "data/summaries/"), "Training Summaries directory")
# tf.app.flags.DEFINE_string("train_dir", os.path.join(com_path, "experiments/"), "Training directory.")

# openpose
tf.app.flags.DEFINE_string(
    "openpose", "openpose_output", "openpose output Data directory")
tf.app.flags.DEFINE_integer("gif_fps", 30, "output gif framerate")
tf.app.flags.DEFINE_integer(
    "verbose", 2, "0:Error, 1:Warning, 2:INFO*(default), 3:debug")
# Help text is Japanese: "index of the person to extract".
tf.app.flags.DEFINE_integer('person_idx', 1, """取得人物INDEX""")


# Train or load
tf.app.flags.DEFINE_boolean("sample", False, "Set to True for sampling.")
tf.app.flags.DEFINE_boolean("use_cpu", False, "Whether to use the CPU")
tf.app.flags.DEFINE_integer("load", 0, "Try to load a previous checkpoint.")

# Misc
tf.app.flags.DEFINE_boolean(
    "use_fp16", False, "Train using fp16 instead of fp32.")

FLAGS = tf.app.flags.FLAGS


def write_parameter_to_file(file_path):
    """Dump the current training hyper-parameters (from FLAGS) to a text file.

    Args
      file_path: string. Path of the parameter file to (over)write.
    """
    # The with-statement closes the file; no explicit close() is needed.
    with open(file_path, 'w') as file_object:
        file_object.write("action = {0}\n".format(FLAGS.action))
        file_object.write("dropout = {0}\n".format(FLAGS.dropout))
        # A non-positive epoch count is written as an empty value.
        file_object.write("epochs = {0}\n".format(
            FLAGS.epochs if FLAGS.epochs > 0 else ''))
        file_object.write("learning_rate = {0}\n".format(FLAGS.learning_rate))
        file_object.write("residual = {0}\n".format(FLAGS.residual))
        file_object.write("depth = {0}\n".format(FLAGS.num_layers))
        file_object.write("linear_size = {0}\n".format(FLAGS.linear_size))
        file_object.write("batch_size = {0}\n".format(FLAGS.batch_size))
        file_object.write("procrustes = {0}\n".format(FLAGS.procrustes))
        file_object.write("maxnorm = {0}\n".format(FLAGS.max_norm))
        file_object.write(
            "batch_normalization = {0}\n".format(FLAGS.batch_norm))
        file_object.write(
            "use_stacked_hourglass = {0}\n".format(FLAGS.use_sh))
        # NOTE: the original wrote "depth" twice; the duplicate was removed.
        file_object.write("predict = 17")


def create_model(session, actions, batch_size):
    """
    Create model and initialize it or load its parameters in a session

    Args
      session: tensorflow session
      actions: list of string. Actions to train/test on
      batch_size: integer. Number of examples in each batch
    Returns
      model: The created (or loaded) model
    Raises
      ValueError if asked to load a model, but the checkpoint specified by
      FLAGS.load cannot be found, or if no checkpoint exists at all.
    """

    model = linear_model.LinearModel(
        FLAGS.linear_size,
        FLAGS.num_layers,
        FLAGS.residual,
        FLAGS.batch_norm,
        FLAGS.max_norm,
        batch_size,
        FLAGS.learning_rate,
        FLAGS.summaries_dir,
        FLAGS.predict_14,
        dtype=tf.float16 if FLAGS.use_fp16 else tf.float32)

    # Load a previously saved model
    ckpt = tf.train.get_checkpoint_state(
        FLAGS.model_dir, latest_filename="checkpoint")

    if ckpt and ckpt.model_checkpoint_path:
        # Check if the specific checkpoint exists
        if FLAGS.load > 0:
            if os.path.isfile(os.path.join(FLAGS.model_dir, "checkpoint-{0}.index".format(FLAGS.load))):
                ckpt_name = os.path.join(
                    FLAGS.model_dir, "checkpoint-{0}".format(FLAGS.load))
                # Write the training configuration to a parameter file
                write_parameter_to_file(file_path=os.path.join(
                    FLAGS.summaries_dir, "parameter.txt"))
            else:
                raise ValueError(
                    "Asked to load checkpoint {0}, but it does not seem to exist".format(FLAGS.load))
        else:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)

        print("Loading model {0}".format(ckpt_name))
        model.saver.restore(session, ckpt.model_checkpoint_path)
        return model
    else:
        print("Could not find checkpoint. Aborting.")
        # BUGFIX: the original `raise(ValueError, msg)` raised a tuple (a
        # TypeError in Python 3) and dereferenced `ckpt` even when it is None.
        raise ValueError("Checkpoint in {0} does not seem to exist".format(
            FLAGS.model_dir))


def train():
    """Train a linear model for 3d pose estimation.

    Loads the Human3.6M 3d data and 2d projections (ground truth or Stacked
    Hourglass predictions), builds/loads the model, then runs FLAGS.epochs
    epochs of training, evaluating on the test set and saving a checkpoint
    after every epoch.
    """

    actions = data_utils.define_actions(FLAGS.action)

    # Load camera parameters
    SUBJECT_IDS = [1, 5, 6, 7, 8, 9, 11]
    rcams = cameras.load_cameras(FLAGS.cameras_path, SUBJECT_IDS)

    # Load 3d data and load (or create) 2d projections
    train_set_3d, test_set_3d, data_mean_3d, data_std_3d, dim_to_ignore_3d, dim_to_use_3d, train_root_positions, test_root_positions = data_utils.read_3d_data(
        actions, FLAGS.data_dir, FLAGS.camera_frame, rcams, FLAGS.predict_14)

    # Read stacked hourglass 2D predictions if use_sh, otherwise use groundtruth 2D projections
    if FLAGS.use_sh:
        train_set_2d, test_set_2d, data_mean_2d, data_std_2d, dim_to_ignore_2d, dim_to_use_2d = data_utils.read_2d_predictions(
            actions, FLAGS.data_dir)
    else:
        train_set_2d, test_set_2d, data_mean_2d, data_std_2d, dim_to_ignore_2d, dim_to_use_2d = data_utils.create_2d_data(
            actions, FLAGS.data_dir, rcams)
    print("done reading and normalizing data.")

    # Avoid using the GPU if requested
    device_count = {"GPU": 0} if FLAGS.use_cpu else {"GPU": 1}
    with tf.Session(config=tf.ConfigProto(
            device_count=device_count,
            allow_soft_placement=True)) as sess:

        # === Create the model ===
        print("Creating %d bi-layers of %d units." %
              (FLAGS.num_layers, FLAGS.linear_size))
        model = create_model(sess, actions, FLAGS.batch_size)
        model.train_writer.add_graph(sess.graph)
        print("Model created")

        # === This is the training loop ===
        # Resume step counting from the loaded checkpoint, if any.
        current_step = 0 if FLAGS.load <= 0 else FLAGS.load + 1

        step_time, loss = 0, 0
        current_epoch = 0
        log_every_n_batches = 100

        for _ in xrange(FLAGS.epochs):
            current_epoch = current_epoch + 1

            # === Load training batches for one epoch ===
            encoder_inputs, decoder_outputs = model.get_all_batches(
                train_set_2d, train_set_3d, FLAGS.camera_frame, training=True)
            nbatches = len(encoder_inputs)
            print("There are {0} train batches".format(nbatches))
            start_time, loss = time.time(), 0.

            # === Loop through all the training batches ===
            for i in range(nbatches):

                if (i+1) % log_every_n_batches == 0:
                    # Print progress every log_every_n_batches batches
                    print(
                        "Working on epoch {0}, batch {1} / {2}... ".format(current_epoch, i+1, nbatches), end="")

                enc_in, dec_out = encoder_inputs[i], decoder_outputs[i]
                step_loss, loss_summary, lr_summary, _ = model.step(
                    sess, enc_in, dec_out, FLAGS.dropout, isTraining=True)

                if (i+1) % log_every_n_batches == 0:
                    # Log and print progress every log_every_n_batches batches
                    model.train_writer.add_summary(loss_summary, current_step)
                    model.train_writer.add_summary(lr_summary, current_step)
                    step_time = (time.time() - start_time)
                    start_time = time.time()
                    print("done in {0:.2f} ms".format(
                        1000*step_time / log_every_n_batches))

                loss += step_loss
                current_step += 1
                # === end looping through training batches ===

            loss = loss / nbatches
            print("=============================\n"
                  "Global step:         %d\n"
                  "Learning rate:       %.2e\n"
                  "Train loss avg:      %.4f\n"
                  "=============================" % (model.global_step.eval(),
                                                     model.learning_rate.eval(), loss))
            # === End training for an epoch ===

            # === Testing after this epoch ===
            if FLAGS.evaluateActionWise:

                # Evaluate each action separately and log the mean error
                cum_err = 0
                for action in actions:

                    # Get 2d and 3d testing data for this action
                    action_test_set_2d = get_action_subset(test_set_2d, action)
                    action_test_set_3d = get_action_subset(test_set_3d, action)
                    encoder_inputs, decoder_outputs = model.get_all_batches(
                        action_test_set_2d, action_test_set_3d, FLAGS.camera_frame, training=False)

                    act_err, _, step_time, loss = evaluate_batches(sess, model,
                                                                   data_mean_3d, data_std_3d, dim_to_use_3d, dim_to_ignore_3d,
                                                                   data_mean_2d, data_std_2d, dim_to_use_2d, dim_to_ignore_2d,
                                                                   current_step, encoder_inputs, decoder_outputs)
                    cum_err = cum_err + act_err

                summaries = sess.run(model.err_mm_summary, {
                                     model.err_mm: float(cum_err/float(len(actions)))})
                model.test_writer.add_summary(summaries, current_step)

            else:

                # Evaluate on the whole test set at once
                n_joints = 17 if not(FLAGS.predict_14) else 14
                encoder_inputs, decoder_outputs = model.get_all_batches(
                    test_set_2d, test_set_3d, FLAGS.camera_frame, training=False)

                total_err, joint_err, step_time, loss = evaluate_batches(sess, model,
                                                                         data_mean_3d, data_std_3d, dim_to_use_3d, dim_to_ignore_3d,
                                                                         data_mean_2d, data_std_2d, dim_to_use_2d, dim_to_ignore_2d,
                                                                         current_step, encoder_inputs, decoder_outputs, current_epoch)

                print("=============================\n"
                      "Step-time (ms):      %.4f\n"
                      "Val loss avg:        %.4f\n"
                      "Val error avg (mm):  %.2f\n"
                      "=============================" % (1000*step_time, loss, total_err))

                for i in range(n_joints):
                    # 6 spaces, right-aligned, 5 decimal places
                    print("Error in joint {0:02d} (mm): {1:>5.2f}".format(
                        i+1, joint_err[i]))
                print("=============================")

                # Log the error to tensorboard
                summaries = sess.run(model.err_mm_summary, {
                                     model.err_mm: total_err})
                model.test_writer.add_summary(summaries, current_step)

            # Save the model
            # BUGFIX: the original used FLAGS.model, which is not a defined
            # flag (AttributeError); the flag is FLAGS.model_dir.
            print("Saving the model... ", end="")
            start_time = time.time()
            model.saver.save(sess, os.path.join(
                FLAGS.model_dir, 'checkpoint'), global_step=current_step)
            print("done in {0:.2f} ms".format(1000*(time.time() - start_time)))

            # Reset global time and loss
            step_time, loss = 0, 0

            sys.stdout.flush()


def get_action_subset(poses_set, action):
    """
    Select from a preloaded dictionary of poses only the entries for one action.

    Args
      poses_set: dictionary with keys k=(subject, action, seqname),
        values v=(nxd matrix of poses)
      action: string. The action that we want to filter out
    Returns
      poses_subset: dictionary with same structure as poses_set, but only with the
        specified action.
    """
    subset = {}
    for key, poses in poses_set.items():
        # The action name is the second element of the key tuple.
        if key[1] == action:
            subset[key] = poses
    return subset


def evaluate_batches(sess, model,
                     data_mean_3d, data_std_3d, dim_to_use_3d, dim_to_ignore_3d,
                     data_mean_2d, data_std_2d, dim_to_use_2d, dim_to_ignore_2d,
                     current_step, encoder_inputs, decoder_outputs, current_epoch=0):
    """
    Generic method that evaluates performance of a list of batches.
    May be used to evaluate all actions or a single action.

    Args
      sess: tensorflow session
      model: the model to evaluate
      data_mean_3d: mean used to normalize the 3d data
      data_std_3d: std used to normalize the 3d data
      dim_to_use_3d: 3d dimensions kept by the model
      dim_to_ignore_3d: 3d dimensions dropped during normalization
      data_mean_2d: mean used to normalize the 2d data
      data_std_2d: std used to normalize the 2d data
      dim_to_use_2d: 2d dimensions kept by the model
      dim_to_ignore_2d: 2d dimensions dropped during normalization
      current_step: current global training step (unused here; kept for API symmetry)
      encoder_inputs: list of normalized 2d input batches
      decoder_outputs: list of normalized 3d ground-truth batches
      current_epoch: if > 0, progress is printed every 100 batches
    Returns

      total_err: mean joint error (mm) over all batches
      joint_err: per-joint mean error (mm)
      step_time: average time per batch, in seconds
      loss: average model loss over all batches
    """

    # Number of joints the model predicts (excluding the root when 17).
    n_joints = 17 if not(FLAGS.predict_14) else 14
    nbatches = len(encoder_inputs)

    # Loop through test examples
    all_dists, start_time, loss = [], time.time(), 0.
    log_every_n_batches = 100
    for i in range(nbatches):

        if current_epoch > 0 and (i+1) % log_every_n_batches == 0:
            print(
                "Working on test epoch {0}, batch {1} / {2}".format(current_epoch, i+1, nbatches))

        enc_in, dec_out = encoder_inputs[i], decoder_outputs[i]
        dp = 1.0  # dropout keep probability is always 1 at test time
        step_loss, loss_summary, poses3d = model.step(
            sess, enc_in, dec_out, dp, isTraining=False)
        loss += step_loss

        # denormalize, so errors below are in millimeters
        enc_in = data_utils.unNormalizeData(
            enc_in,  data_mean_2d, data_std_2d, dim_to_ignore_2d)
        dec_out = data_utils.unNormalizeData(
            dec_out, data_mean_3d, data_std_3d, dim_to_ignore_3d)
        poses3d = data_utils.unNormalizeData(
            poses3d, data_mean_3d, data_std_3d, dim_to_ignore_3d)

        # Keep only the relevant dimensions
        # When predicting 17 joints, the root (first 3 dims) is prepended.
        dtu3d = np.hstack((np.arange(3), dim_to_use_3d)) if not(
            FLAGS.predict_14) else dim_to_use_3d

        dec_out = dec_out[:, dtu3d]
        poses3d = poses3d[:, dtu3d]

        assert dec_out.shape[0] == FLAGS.batch_size
        assert poses3d.shape[0] == FLAGS.batch_size

        if FLAGS.procrustes:
            # Apply per-frame procrustes alignment if asked to do so
            for j in range(FLAGS.batch_size):
                gt = np.reshape(dec_out[j, :], [-1, 3])
                out = np.reshape(poses3d[j, :], [-1, 3])
                _, Z, T, b, c = procrustes.compute_similarity_transform(
                    gt, out, compute_optimal_scale=True)
                # Align the prediction to the ground truth (scale, rotation, translation)
                out = (b*out.dot(T))+c

                poses3d[j, :] = np.reshape(
                    out, [-1, 17*3]) if not(FLAGS.predict_14) else np.reshape(out, [-1, 14*3])

        # Compute Euclidean distance error per joint
        # Squared error between prediction and expected output
        sqerr = (poses3d - dec_out)**2
        # Array with L2 error per joint in mm
        dists = np.zeros((sqerr.shape[0], n_joints))
        dist_idx = 0
        for k in np.arange(0, n_joints*3, 3):
            # Sum across X,Y, and Z dimenstions to obtain L2 distance
            dists[:, dist_idx] = np.sqrt(np.sum(sqerr[:, k:k+3], axis=1))
            dist_idx = dist_idx + 1

        all_dists.append(dists)
        assert sqerr.shape[0] == FLAGS.batch_size

    step_time = (time.time() - start_time) / nbatches
    loss = loss / nbatches

    all_dists = np.vstack(all_dists)

    # Error per joint and total for all passed batches
    joint_err = np.mean(all_dists, axis=0)
    total_err = np.mean(all_dists)

    return total_err, joint_err, step_time, loss


def sample():
    """Get samples from a model and visualize them.

    Loads the test data, runs the model on every test sequence, optionally
    converts predictions back to world coordinates, and plots a grid of
    2d inputs, 3d ground truth and 3d predictions with matplotlib.
    """

    actions = data_utils.define_actions(FLAGS.action)

    # Load camera parameters
    SUBJECT_IDS = [1, 5, 6, 7, 8, 9, 11]
    rcams = cameras.load_cameras(FLAGS.cameras_path, SUBJECT_IDS)

    # Load 3d data and load (or create) 2d projections
    train_set_3d, test_set_3d, data_mean_3d, data_std_3d, dim_to_ignore_3d, dim_to_use_3d, train_root_positions, test_root_positions = data_utils.read_3d_data(
        actions, FLAGS.data_dir, FLAGS.camera_frame, rcams, FLAGS.predict_14)

    # 2d data: stacked-hourglass predictions or ground-truth projections
    if FLAGS.use_sh:
        train_set_2d, test_set_2d, data_mean_2d, data_std_2d, dim_to_ignore_2d, dim_to_use_2d = data_utils.read_2d_predictions(
            actions, FLAGS.data_dir)
    else:
        train_set_2d, test_set_2d, data_mean_2d, data_std_2d, dim_to_ignore_2d, dim_to_use_2d = data_utils.create_2d_data(
            actions, FLAGS.data_dir, rcams)
    print("done reading and normalizing data.")

    device_count = {"GPU": 0} if FLAGS.use_cpu else {"GPU": 1}
    with tf.Session(config=tf.ConfigProto(device_count=device_count)) as sess:
        # === Create the model ===
        print("Creating %d layers of %d units." %
              (FLAGS.num_layers, FLAGS.linear_size))
        batch_size = 128
        model = create_model(sess, actions, batch_size)
        print("Model loaded")

        for key2d in test_set_2d.keys():

            (subj, b, fname) = key2d
            print("Subject: {}, action: {}, fname: {}".format(subj, b, fname))

            # keys should be the same if 3d is in camera coordinates
            key3d = key2d if FLAGS.camera_frame else (
                subj, b, '{0}.h5'.format(fname.split('.')[0]))
            # Strip the '-sh' suffix from stacked-hourglass keys in camera frame
            key3d = (subj, b, fname[:-3]) if (fname.endswith('-sh')
                                              ) and FLAGS.camera_frame else key3d

            enc_in = test_set_2d[key2d]
            n2d, _ = enc_in.shape
            dec_out = test_set_3d[key3d]
            n3d, _ = dec_out.shape
            assert n2d == n3d

            # Split into about-same-size batches
            enc_in = np.array_split(enc_in,  n2d // batch_size)
            dec_out = np.array_split(dec_out, n3d // batch_size)
            all_poses_3d = []

            for bidx in range(len(enc_in)):

                # Dropout probability 0 (keep probability 1) for sampling
                dp = 1.0
                _, _, poses3d = model.step(
                    sess, enc_in[bidx], dec_out[bidx], dp, isTraining=False)

                # denormalize (each list entry is denormalized in place)
                enc_in[bidx] = data_utils.unNormalizeData(
                    enc_in[bidx], data_mean_2d, data_std_2d, dim_to_ignore_2d)
                dec_out[bidx] = data_utils.unNormalizeData(
                    dec_out[bidx], data_mean_3d, data_std_3d, dim_to_ignore_3d)
                poses3d = data_utils.unNormalizeData(
                    poses3d, data_mean_3d, data_std_3d, dim_to_ignore_3d)
                all_poses_3d.append(poses3d)

            # Put all the poses together
            enc_in, dec_out, poses3d = map(
                np.vstack, [enc_in, dec_out, all_poses_3d])

            # Convert back to world coordinates
            if FLAGS.camera_frame:
                N_CAMERAS = 4
                N_JOINTS_H36M = 32

                # Add global position back
                dec_out = dec_out + \
                    np.tile(test_root_positions[key3d], [1, N_JOINTS_H36M])

                # Load the appropriate camera
                subj, _, sname = key3d

                cname = sname.split('.')[1]  # <-- camera name
                # cams of this subject
                scams = {(subj, c+1): rcams[(subj, c+1)]
                         for c in range(N_CAMERAS)}
                # index of camera used
                scam_idx = [scams[(subj, c+1)][-1]
                            for c in range(N_CAMERAS)].index(cname)
                the_cam = scams[(subj, scam_idx+1)]  # <-- the camera used
                R, T, f, c, k, p, name = the_cam
                assert name == cname

                def cam2world_centered(data_3d_camframe):
                    # Rotate/translate camera-frame poses back to world frame,
                    # then re-center them on the root joint.
                    data_3d_worldframe = cameras.camera_to_world_frame(
                        data_3d_camframe.reshape((-1, 3)), R, T)
                    data_3d_worldframe = data_3d_worldframe.reshape(
                        (-1, N_JOINTS_H36M*3))
                    # subtract root translation
                    return data_3d_worldframe - np.tile(data_3d_worldframe[:, :3], (1, N_JOINTS_H36M))

                # Apply inverse rotation and translation
                dec_out = cam2world_centered(dec_out)
                poses3d = cam2world_centered(poses3d)

    # Grab a random batch to visualize
    # NOTE(review): only the LAST key2d iterated above is visualized here.
    enc_in, dec_out, poses3d = map(np.vstack, [enc_in, dec_out, poses3d])
    idx = np.random.permutation(enc_in.shape[0])
    enc_in, dec_out, poses3d = enc_in[idx, :], dec_out[idx, :], poses3d[idx, :]

    # Visualize random samples
    import matplotlib.gridspec as gridspec

    # 1080p	= 1,920 x 1,080
    fig = plt.figure(figsize=(19.2, 10.8))

    gs1 = gridspec.GridSpec(5, 9)  # 5 rows, 9 columns
    gs1.update(wspace=-0.00, hspace=0.05)  # set the spacing between axes.
    plt.axis('off')

    subplot_idx, exidx = 1, 1
    nsamples = 15
    for i in np.arange(nsamples):

        # Plot 2d pose
        ax1 = plt.subplot(gs1[subplot_idx-1])
        p2d = enc_in[exidx, :]
        viz.show2Dpose(p2d, ax1)
        ax1.invert_yaxis()

        # Plot 3d gt
        ax2 = plt.subplot(gs1[subplot_idx], projection='3d')
        p3d = dec_out[exidx, :]
        viz.show3Dpose(p3d, ax2)

        # Plot 3d predictions
        ax3 = plt.subplot(gs1[subplot_idx+1], projection='3d')
        p3d = poses3d[exidx, :]
        viz.show3Dpose(p3d, ax3, lcolor="#9b59b6", rcolor="#2ecc71")

        exidx = exidx + 1
        subplot_idx = subplot_idx + 3

    plt.show()


def main(_):
    """Entry point: visualize samples or train, depending on FLAGS.sample."""
    run = sample if FLAGS.sample else train
    run()


if __name__ == "__main__":
    # tf.app.run() parses the command-line flags and then calls main().
    tf.app.run()
